# ===----------------------------------------------------------------------=== #
# Copyright (c) 2025, Modular Inc. All rights reserved.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions:
# https://llvm.org/LICENSE.txt
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===----------------------------------------------------------------------=== #

from math import nan

from gpu.host import DeviceContext
from memory import bitcast


# CHECK-LABEL: test_e4m3fn_initialization
# CHECK: 0.0, 0.001953125, 0.00390625, 0.005859375, 0.0078125, 0.009765625, 0.01171875, 0.013671875,
# CHECK: 0.015625, 0.017578125, 0.01953125, 0.021484375, 0.0234375, 0.025390625, 0.02734375, 0.029296875,
# CHECK: 0.03125, 0.03515625, 0.0390625, 0.04296875, 0.046875, 0.05078125, 0.0546875, 0.05859375,
# CHECK: 0.0625, 0.0703125, 0.078125, 0.0859375, 0.09375, 0.1015625, 0.109375, 0.1171875,
# CHECK: 0.125, 0.140625, 0.15625, 0.171875, 0.1875, 0.203125, 0.21875, 0.234375,
# CHECK: 0.25, 0.28125, 0.3125, 0.34375, 0.375, 0.40625, 0.4375, 0.46875,
# CHECK: 0.5, 0.5625, 0.625, 0.6875, 0.75, 0.8125, 0.875, 0.9375,
# CHECK: 1.0, 1.125, 1.25, 1.375, 1.5, 1.625, 1.75, 1.875,
# CHECK: 2.0, 2.25, 2.5, 2.75, 3.0, 3.25, 3.5, 3.75,
# CHECK: 4.0, 4.5, 5.0, 5.5, 6.0, 6.5, 7.0, 7.5,
# CHECK: 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0,
# CHECK: 16.0, 18.0, 20.0, 22.0, 24.0, 26.0, 28.0, 30.0,
# CHECK: 32.0, 36.0, 40.0, 44.0, 48.0, 52.0, 56.0, 60.0,
# CHECK: 64.0, 72.0, 80.0, 88.0, 96.0, 104.0, 112.0, 120.0,
# CHECK: 128.0, 144.0, 160.0, 176.0, 192.0, 208.0, 224.0, 240.0,
# CHECK: 256.0, 288.0, 320.0, 352.0, 384.0, 416.0, 448.0, nan,
# CHECK: -0.0, -0.001953125, -0.00390625, -0.005859375, -0.0078125, -0.009765625, -0.01171875, -0.013671875,
# CHECK: -0.015625, -0.017578125, -0.01953125, -0.021484375, -0.0234375, -0.025390625, -0.02734375, -0.029296875,
# CHECK: -0.03125, -0.03515625, -0.0390625, -0.04296875, -0.046875, -0.05078125, -0.0546875, -0.05859375,
# CHECK: -0.0625, -0.0703125, -0.078125, -0.0859375, -0.09375, -0.1015625, -0.109375, -0.1171875,
# CHECK: -0.125, -0.140625, -0.15625, -0.171875, -0.1875, -0.203125, -0.21875, -0.234375,
# CHECK: -0.25, -0.28125, -0.3125, -0.34375, -0.375, -0.40625, -0.4375, -0.46875,
# CHECK: -0.5, -0.5625, -0.625, -0.6875, -0.75, -0.8125, -0.875, -0.9375,
# CHECK: -1.0, -1.125, -1.25, -1.375, -1.5, -1.625, -1.75, -1.875,
# CHECK: -2.0, -2.25, -2.5, -2.75, -3.0, -3.25, -3.5, -3.75,
# CHECK: -4.0, -4.5, -5.0, -5.5, -6.0, -6.5, -7.0, -7.5,
# CHECK: -8.0, -9.0, -10.0, -11.0, -12.0, -13.0, -14.0, -15.0,
# CHECK: -16.0, -18.0, -20.0, -22.0, -24.0, -26.0, -28.0, -30.0,
# CHECK: -32.0, -36.0, -40.0, -44.0, -48.0, -52.0, -56.0, -60.0,
# CHECK: -64.0, -72.0, -80.0, -88.0, -96.0, -104.0, -112.0, -120.0,
# CHECK: -128.0, -144.0, -160.0, -176.0, -192.0, -208.0, -224.0, -240.0,
# CHECK: -256.0, -288.0, -320.0, -352.0, -384.0, -416.0, -448.0, nan,
fn test_e4m3fn_initialization():
    """Builds a 256-lane float8_e4m3fn vector from literals and prints it.

    The lanes are written out explicitly: 127 non-negative values in
    ascending order followed by NaN, then the negated counterparts followed
    by NaN again. The vector is printed eight lanes per line so the output
    lines up with the expected-output comments preceding this function.
    """
    print("== test_e4m3fn_initialization")

    # Both halves of the table end with the same NaN lane.
    var e4m3_nan = nan[DType.float8_e4m3fn]()

    # Each source row below corresponds to one printed row of eight lanes.
    var table = SIMD[DType.float8_e4m3fn, 256](
        # Non-negative half.
        0.0, 0.001953125, 0.00390625, 0.005859375, 0.0078125, 0.009765625, 0.01171875, 0.013671875,
        0.015625, 0.017578125, 0.01953125, 0.021484375, 0.0234375, 0.025390625, 0.02734375, 0.029296875,
        0.03125, 0.03515625, 0.0390625, 0.04296875, 0.046875, 0.05078125, 0.0546875, 0.05859375,
        0.0625, 0.0703125, 0.078125, 0.0859375, 0.09375, 0.1015625, 0.109375, 0.1171875,
        0.125, 0.140625, 0.15625, 0.171875, 0.1875, 0.203125, 0.21875, 0.234375,
        0.25, 0.28125, 0.3125, 0.34375, 0.375, 0.40625, 0.4375, 0.46875,
        0.5, 0.5625, 0.625, 0.6875, 0.75, 0.8125, 0.875, 0.9375,
        1.0, 1.125, 1.25, 1.375, 1.5, 1.625, 1.75, 1.875,
        2.0, 2.25, 2.5, 2.75, 3.0, 3.25, 3.5, 3.75,
        4.0, 4.5, 5.0, 5.5, 6.0, 6.5, 7.0, 7.5,
        8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0,
        16.0, 18.0, 20.0, 22.0, 24.0, 26.0, 28.0, 30.0,
        32.0, 36.0, 40.0, 44.0, 48.0, 52.0, 56.0, 60.0,
        64.0, 72.0, 80.0, 88.0, 96.0, 104.0, 112.0, 120.0,
        128.0, 144.0, 160.0, 176.0, 192.0, 208.0, 224.0, 240.0,
        256.0, 288.0, 320.0, 352.0, 384.0, 416.0, 448.0, e4m3_nan,
        # Negative half.
        -0.0, -0.001953125, -0.00390625, -0.005859375, -0.0078125, -0.009765625, -0.01171875, -0.013671875,
        -0.015625, -0.017578125, -0.01953125, -0.021484375, -0.0234375, -0.025390625, -0.02734375, -0.029296875,
        -0.03125, -0.03515625, -0.0390625, -0.04296875, -0.046875, -0.05078125, -0.0546875, -0.05859375,
        -0.0625, -0.0703125, -0.078125, -0.0859375, -0.09375, -0.1015625, -0.109375, -0.1171875,
        -0.125, -0.140625, -0.15625, -0.171875, -0.1875, -0.203125, -0.21875, -0.234375,
        -0.25, -0.28125, -0.3125, -0.34375, -0.375, -0.40625, -0.4375, -0.46875,
        -0.5, -0.5625, -0.625, -0.6875, -0.75, -0.8125, -0.875, -0.9375,
        -1.0, -1.125, -1.25, -1.375, -1.5, -1.625, -1.75, -1.875,
        -2.0, -2.25, -2.5, -2.75, -3.0, -3.25, -3.5, -3.75,
        -4.0, -4.5, -5.0, -5.5, -6.0, -6.5, -7.0, -7.5,
        -8.0, -9.0, -10.0, -11.0, -12.0, -13.0, -14.0, -15.0,
        -16.0, -18.0, -20.0, -22.0, -24.0, -26.0, -28.0, -30.0,
        -32.0, -36.0, -40.0, -44.0, -48.0, -52.0, -56.0, -60.0,
        -64.0, -72.0, -80.0, -88.0, -96.0, -104.0, -112.0, -120.0,
        -128.0, -144.0, -160.0, -176.0, -192.0, -208.0, -224.0, -240.0,
        -256.0, -288.0, -320.0, -352.0, -384.0, -416.0, -448.0, e4m3_nan,
    )

    # 32 rows of 8 lanes; every lane is followed by ", " and every row is
    # terminated by a newline.
    for row in range(32):
        for col in range(8):
            print(table[row * 8 + col], end=", ")
        print("")

# CHECK-LABEL: test_simd_e4m3_to_f32
# CHECK: 0.0, 0.001953125, 0.00390625, 0.005859375, 0.0078125, 0.009765625, 0.01171875, 0.013671875,
# CHECK: 0.015625, 0.017578125, 0.01953125, 0.021484375, 0.0234375, 0.025390625, 0.02734375, 0.029296875,
# CHECK: 0.03125, 0.03515625, 0.0390625, 0.04296875, 0.046875, 0.05078125, 0.0546875, 0.05859375,
# CHECK: 0.0625, 0.0703125, 0.078125, 0.0859375, 0.09375, 0.1015625, 0.109375, 0.1171875,
# CHECK: 0.125, 0.140625, 0.15625, 0.171875, 0.1875, 0.203125, 0.21875, 0.234375,
# CHECK: 0.25, 0.28125, 0.3125, 0.34375, 0.375, 0.40625, 0.4375, 0.46875,
# CHECK: 0.5, 0.5625, 0.625, 0.6875, 0.75, 0.8125, 0.875, 0.9375,
# CHECK: 1.0, 1.125, 1.25, 1.375, 1.5, 1.625, 1.75, 1.875,
# CHECK: 2.0, 2.25, 2.5, 2.75, 3.0, 3.25, 3.5, 3.75,
# CHECK: 4.0, 4.5, 5.0, 5.5, 6.0, 6.5, 7.0, 7.5,
# CHECK: 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0,
# CHECK: 16.0, 18.0, 20.0, 22.0, 24.0, 26.0, 28.0, 30.0,
# CHECK: 32.0, 36.0, 40.0, 44.0, 48.0, 52.0, 56.0, 60.0,
# CHECK: 64.0, 72.0, 80.0, 88.0, 96.0, 104.0, 112.0, 120.0,
# CHECK: 128.0, 144.0, 160.0, 176.0, 192.0, 208.0, 224.0, 240.0,
# CHECK: 256.0, 288.0, 320.0, 352.0, 384.0, 416.0, 448.0, nan,
# CHECK: -0.0, -0.001953125, -0.00390625, -0.005859375, -0.0078125, -0.009765625, -0.01171875, -0.013671875,
# CHECK: -0.015625, -0.017578125, -0.01953125, -0.021484375, -0.0234375, -0.025390625, -0.02734375, -0.029296875,
# CHECK: -0.03125, -0.03515625, -0.0390625, -0.04296875, -0.046875, -0.05078125, -0.0546875, -0.05859375,
# CHECK: -0.0625, -0.0703125, -0.078125, -0.0859375, -0.09375, -0.1015625, -0.109375, -0.1171875,
# CHECK: -0.125, -0.140625, -0.15625, -0.171875, -0.1875, -0.203125, -0.21875, -0.234375,
# CHECK: -0.25, -0.28125, -0.3125, -0.34375, -0.375, -0.40625, -0.4375, -0.46875,
# CHECK: -0.5, -0.5625, -0.625, -0.6875, -0.75, -0.8125, -0.875, -0.9375,
# CHECK: -1.0, -1.125, -1.25, -1.375, -1.5, -1.625, -1.75, -1.875,
# CHECK: -2.0, -2.25, -2.5, -2.75, -3.0, -3.25, -3.5, -3.75,
# CHECK: -4.0, -4.5, -5.0, -5.5, -6.0, -6.5, -7.0, -7.5,
# CHECK: -8.0, -9.0, -10.0, -11.0, -12.0, -13.0, -14.0, -15.0,
# CHECK: -16.0, -18.0, -20.0, -22.0, -24.0, -26.0, -28.0, -30.0,
# CHECK: -32.0, -36.0, -40.0, -44.0, -48.0, -52.0, -56.0, -60.0,
# CHECK: -64.0, -72.0, -80.0, -88.0, -96.0, -104.0, -112.0, -120.0,
# CHECK: -128.0, -144.0, -160.0, -176.0, -192.0, -208.0, -224.0, -240.0,
# CHECK: -256.0, -288.0, -320.0, -352.0, -384.0, -416.0, -448.0, nan,
fn test_simd_e4m3_to_f32():
    """Prints all 256 float8_e4m3fn bit patterns after casting to float32.

    Lane `i` is produced by bitcasting `UInt8(i)` to float8_e4m3fn. The cast
    is done directly in this function; no DeviceContext is involved. Output
    is eight values per line, each followed by ", ".
    """
    print("== test_simd_e4m3_to_f32")

    comptime num_patterns = 256
    comptime row_width = 8

    # Fill lane `bits` with the float8 value whose raw byte is `bits`.
    var fp8_values = SIMD[DType.float8_e4m3fn, num_patterns](0.0)
    for bits in range(num_patterns):
        fp8_values[bits] = bitcast[DType.float8_e4m3fn](UInt8(bits))

    var widened = fp8_values.cast[DType.float32]()

    # Emit a line break after every `row_width`-th lane.
    for lane in range(num_patterns):
        print(widened[lane], end=", ")
        if (lane + 1) % row_width == 0:
            print("")


# CHECK-LABEL: test_simd_e4m3_to_f16
# CHECK: 0.0, 0.001953125, 0.00390625, 0.005859375, 0.0078125, 0.009765625, 0.01171875, 0.013671875,
# CHECK: 0.015625, 0.017578125, 0.01953125, 0.021484375, 0.0234375, 0.025390625, 0.02734375, 0.029296875,
# CHECK: 0.03125, 0.03515625, 0.0390625, 0.04296875, 0.046875, 0.05078125, 0.0546875, 0.05859375,
# CHECK: 0.0625, 0.0703125, 0.078125, 0.0859375, 0.09375, 0.1015625, 0.109375, 0.1171875,
# CHECK: 0.125, 0.140625, 0.15625, 0.171875, 0.1875, 0.203125, 0.21875, 0.234375,
# CHECK: 0.25, 0.28125, 0.3125, 0.34375, 0.375, 0.40625, 0.4375, 0.46875,
# CHECK: 0.5, 0.5625, 0.625, 0.6875, 0.75, 0.8125, 0.875, 0.9375,
# CHECK: 1.0, 1.125, 1.25, 1.375, 1.5, 1.625, 1.75, 1.875,
# CHECK: 2.0, 2.25, 2.5, 2.75, 3.0, 3.25, 3.5, 3.75,
# CHECK: 4.0, 4.5, 5.0, 5.5, 6.0, 6.5, 7.0, 7.5,
# CHECK: 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0,
# CHECK: 16.0, 18.0, 20.0, 22.0, 24.0, 26.0, 28.0, 30.0,
# CHECK: 32.0, 36.0, 40.0, 44.0, 48.0, 52.0, 56.0, 60.0,
# CHECK: 64.0, 72.0, 80.0, 88.0, 96.0, 104.0, 112.0, 120.0,
# CHECK: 128.0, 144.0, 160.0, 176.0, 192.0, 208.0, 224.0, 240.0,
# CHECK: 256.0, 288.0, 320.0, 352.0, 384.0, 416.0, 448.0, nan,
# CHECK: -0.0, -0.001953125, -0.00390625, -0.005859375, -0.0078125, -0.009765625, -0.01171875, -0.013671875,
# CHECK: -0.015625, -0.017578125, -0.01953125, -0.021484375, -0.0234375, -0.025390625, -0.02734375, -0.029296875,
# CHECK: -0.03125, -0.03515625, -0.0390625, -0.04296875, -0.046875, -0.05078125, -0.0546875, -0.05859375,
# CHECK: -0.0625, -0.0703125, -0.078125, -0.0859375, -0.09375, -0.1015625, -0.109375, -0.1171875,
# CHECK: -0.125, -0.140625, -0.15625, -0.171875, -0.1875, -0.203125, -0.21875, -0.234375,
# CHECK: -0.25, -0.28125, -0.3125, -0.34375, -0.375, -0.40625, -0.4375, -0.46875,
# CHECK: -0.5, -0.5625, -0.625, -0.6875, -0.75, -0.8125, -0.875, -0.9375,
# CHECK: -1.0, -1.125, -1.25, -1.375, -1.5, -1.625, -1.75, -1.875,
# CHECK: -2.0, -2.25, -2.5, -2.75, -3.0, -3.25, -3.5, -3.75,
# CHECK: -4.0, -4.5, -5.0, -5.5, -6.0, -6.5, -7.0, -7.5,
# CHECK: -8.0, -9.0, -10.0, -11.0, -12.0, -13.0, -14.0, -15.0,
# CHECK: -16.0, -18.0, -20.0, -22.0, -24.0, -26.0, -28.0, -30.0,
# CHECK: -32.0, -36.0, -40.0, -44.0, -48.0, -52.0, -56.0, -60.0,
# CHECK: -64.0, -72.0, -80.0, -88.0, -96.0, -104.0, -112.0, -120.0,
# CHECK: -128.0, -144.0, -160.0, -176.0, -192.0, -208.0, -224.0, -240.0,
# CHECK: -256.0, -288.0, -320.0, -352.0, -384.0, -416.0, -448.0, nan,
fn test_simd_e4m3_to_f16():
    """Prints all 256 float8_e4m3fn bit patterns after casting to float16.

    Lane `i` is produced by bitcasting `UInt8(i)` to float8_e4m3fn. The cast
    is done directly in this function; no DeviceContext is involved. Output
    is eight values per line, each followed by ", ".
    """
    print("== test_simd_e4m3_to_f16")

    comptime num_patterns = 256
    comptime row_width = 8

    # Fill lane `bits` with the float8 value whose raw byte is `bits`.
    var fp8_values = SIMD[DType.float8_e4m3fn, num_patterns](0.0)
    for bits in range(num_patterns):
        fp8_values[bits] = bitcast[DType.float8_e4m3fn](UInt8(bits))

    var widened = fp8_values.cast[DType.float16]()

    # Emit a line break after every `row_width`-th lane.
    for lane in range(num_patterns):
        print(widened[lane], end=", ")
        if (lane + 1) % row_width == 0:
            print("")


# CHECK-LABEL: test_simd_e4m3_to_bf16
# CHECK: 0.0, 0.001953125, 0.00390625, 0.005859375, 0.0078125, 0.009765625, 0.01171875, 0.013671875,
# CHECK: 0.015625, 0.017578125, 0.01953125, 0.021484375, 0.0234375, 0.025390625, 0.02734375, 0.029296875,
# CHECK: 0.03125, 0.03515625, 0.0390625, 0.04296875, 0.046875, 0.05078125, 0.0546875, 0.05859375,
# CHECK: 0.0625, 0.0703125, 0.078125, 0.0859375, 0.09375, 0.1015625, 0.109375, 0.1171875,
# CHECK: 0.125, 0.140625, 0.15625, 0.171875, 0.1875, 0.203125, 0.21875, 0.234375,
# CHECK: 0.25, 0.28125, 0.3125, 0.34375, 0.375, 0.40625, 0.4375, 0.46875,
# CHECK: 0.5, 0.5625, 0.625, 0.6875, 0.75, 0.8125, 0.875, 0.9375,
# CHECK: 1.0, 1.125, 1.25, 1.375, 1.5, 1.625, 1.75, 1.875,
# CHECK: 2.0, 2.25, 2.5, 2.75, 3.0, 3.25, 3.5, 3.75,
# CHECK: 4.0, 4.5, 5.0, 5.5, 6.0, 6.5, 7.0, 7.5,
# CHECK: 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0,
# CHECK: 16.0, 18.0, 20.0, 22.0, 24.0, 26.0, 28.0, 30.0,
# CHECK: 32.0, 36.0, 40.0, 44.0, 48.0, 52.0, 56.0, 60.0,
# CHECK: 64.0, 72.0, 80.0, 88.0, 96.0, 104.0, 112.0, 120.0,
# CHECK: 128.0, 144.0, 160.0, 176.0, 192.0, 208.0, 224.0, 240.0,
# CHECK: 256.0, 288.0, 320.0, 352.0, 384.0, 416.0, 448.0, nan,
# CHECK: -0.0, -0.001953125, -0.00390625, -0.005859375, -0.0078125, -0.009765625, -0.01171875, -0.013671875,
# CHECK: -0.015625, -0.017578125, -0.01953125, -0.021484375, -0.0234375, -0.025390625, -0.02734375, -0.029296875,
# CHECK: -0.03125, -0.03515625, -0.0390625, -0.04296875, -0.046875, -0.05078125, -0.0546875, -0.05859375,
# CHECK: -0.0625, -0.0703125, -0.078125, -0.0859375, -0.09375, -0.1015625, -0.109375, -0.1171875,
# CHECK: -0.125, -0.140625, -0.15625, -0.171875, -0.1875, -0.203125, -0.21875, -0.234375,
# CHECK: -0.25, -0.28125, -0.3125, -0.34375, -0.375, -0.40625, -0.4375, -0.46875,
# CHECK: -0.5, -0.5625, -0.625, -0.6875, -0.75, -0.8125, -0.875, -0.9375,
# CHECK: -1.0, -1.125, -1.25, -1.375, -1.5, -1.625, -1.75, -1.875,
# CHECK: -2.0, -2.25, -2.5, -2.75, -3.0, -3.25, -3.5, -3.75,
# CHECK: -4.0, -4.5, -5.0, -5.5, -6.0, -6.5, -7.0, -7.5,
# CHECK: -8.0, -9.0, -10.0, -11.0, -12.0, -13.0, -14.0, -15.0,
# CHECK: -16.0, -18.0, -20.0, -22.0, -24.0, -26.0, -28.0, -30.0,
# CHECK: -32.0, -36.0, -40.0, -44.0, -48.0, -52.0, -56.0, -60.0,
# CHECK: -64.0, -72.0, -80.0, -88.0, -96.0, -104.0, -112.0, -120.0,
# CHECK: -128.0, -144.0, -160.0, -176.0, -192.0, -208.0, -224.0, -240.0,
# CHECK: -256.0, -288.0, -320.0, -352.0, -384.0, -416.0, -448.0, nan,
fn test_simd_e4m3_to_bf16():
    """Prints all 256 float8_e4m3fn bit patterns after casting to bfloat16.

    Lane `i` is produced by bitcasting `UInt8(i)` to float8_e4m3fn. The cast
    is done directly in this function; no DeviceContext is involved. Output
    is eight values per line, each followed by ", ".
    """
    print("== test_simd_e4m3_to_bf16")

    comptime num_patterns = 256
    comptime row_width = 8

    # Fill lane `bits` with the float8 value whose raw byte is `bits`.
    var fp8_values = SIMD[DType.float8_e4m3fn, num_patterns](0.0)
    for bits in range(num_patterns):
        fp8_values[bits] = bitcast[DType.float8_e4m3fn](UInt8(bits))

    var widened = fp8_values.cast[DType.bfloat16]()

    # Emit a line break after every `row_width`-th lane.
    for lane in range(num_patterns):
        print(widened[lane], end=", ")
        if (lane + 1) % row_width == 0:
            print("")


# CHECK-LABEL: test_simd_f32_to_e4m3
# CHECK: -256.0, -256.0, -256.0, -256.0, -256.0, -256.0, -256.0, -256.0,
# CHECK: -256.0, -240.0, -240.0, -240.0, -240.0, -240.0, -240.0, -240.0,
# CHECK: -240.0, -240.0, -240.0, -240.0, -240.0, -240.0, -240.0, -240.0,
# CHECK: -224.0, -224.0, -224.0, -224.0, -224.0, -224.0, -224.0, -224.0,
# CHECK: -224.0, -224.0, -224.0, -224.0, -224.0, -224.0, -224.0, -224.0,
# CHECK: -224.0, -208.0, -208.0, -208.0, -208.0, -208.0, -208.0, -208.0,
# CHECK: -208.0, -208.0, -208.0, -208.0, -208.0, -208.0, -208.0, -208.0,
# CHECK: -192.0, -192.0, -192.0, -192.0, -192.0, -192.0, -192.0, -192.0,
# CHECK: -192.0, -192.0, -192.0, -192.0, -192.0, -192.0, -192.0, -192.0,
# CHECK: -192.0, -176.0, -176.0, -176.0, -176.0, -176.0, -176.0, -176.0,
# CHECK: -176.0, -176.0, -176.0, -176.0, -176.0, -176.0, -176.0, -176.0,
# CHECK: -160.0, -160.0, -160.0, -160.0, -160.0, -160.0, -160.0, -160.0,
# CHECK: -160.0, -160.0, -160.0, -160.0, -160.0, -160.0, -160.0, -160.0,
# CHECK: -160.0, -144.0, -144.0, -144.0, -144.0, -144.0, -144.0, -144.0,
# CHECK: -144.0, -144.0, -144.0, -144.0, -144.0, -144.0, -144.0, -144.0,
# CHECK: -128.0, -128.0, -128.0, -128.0, -128.0, -128.0, -128.0, -128.0,
# CHECK: -128.0, -128.0, -128.0, -128.0, -128.0, -120.0, -120.0, -120.0,
# CHECK: -120.0, -120.0, -120.0, -120.0, -112.0, -112.0, -112.0, -112.0,
# CHECK: -112.0, -112.0, -112.0, -112.0, -112.0, -104.0, -104.0, -104.0,
# CHECK: -104.0, -104.0, -104.0, -104.0, -96.0, -96.0, -96.0, -96.0,
# CHECK: -96.0, -96.0, -96.0, -96.0, -96.0, -88.0, -88.0, -88.0,
# CHECK: -88.0, -88.0, -88.0, -88.0, -80.0, -80.0, -80.0, -80.0,
# CHECK: -80.0, -80.0, -80.0, -80.0, -80.0, -72.0, -72.0, -72.0,
# CHECK: -72.0, -72.0, -72.0, -72.0, -64.0, -64.0, -64.0, -64.0,
# CHECK: -64.0, -64.0, -64.0, -60.0, -60.0, -60.0, -56.0, -56.0,
# CHECK: -56.0, -56.0, -56.0, -52.0, -52.0, -52.0, -48.0, -48.0,
# CHECK: -48.0, -48.0, -48.0, -44.0, -44.0, -44.0, -40.0, -40.0,
# CHECK: -40.0, -40.0, -40.0, -36.0, -36.0, -36.0, -32.0, -32.0,
# CHECK: -32.0, -32.0, -30.0, -28.0, -28.0, -28.0, -26.0, -24.0,
# CHECK: -24.0, -24.0, -22.0, -20.0, -20.0, -20.0, -18.0, -16.0,
# CHECK: -16.0, -15.0, -14.0, -13.0, -12.0, -11.0, -10.0, -9.0,
# CHECK: -8.0, -7.0, -6.0, -5.0, -4.0, -3.0, -2.0, -1.0,
# CHECK: 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0,
# CHECK: 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0,
# CHECK: 16.0, 16.0, 18.0, 20.0, 20.0, 20.0, 22.0, 24.0,
# CHECK: 24.0, 24.0, 26.0, 28.0, 28.0, 28.0, 30.0, 32.0,
# CHECK: 32.0, 32.0, 32.0, 36.0, 36.0, 36.0, 40.0, 40.0,
# CHECK: 40.0, 40.0, 40.0, 44.0, 44.0, 44.0, 48.0, 48.0,
# CHECK: 48.0, 48.0, 48.0, 52.0, 52.0, 52.0, 56.0, 56.0,
# CHECK: 56.0, 56.0, 56.0, 60.0, 60.0, 60.0, 64.0, 64.0,
# CHECK: 64.0, 64.0, 64.0, 64.0, 64.0, 72.0, 72.0, 72.0,
# CHECK: 72.0, 72.0, 72.0, 72.0, 80.0, 80.0, 80.0, 80.0,
# CHECK: 80.0, 80.0, 80.0, 80.0, 80.0, 88.0, 88.0, 88.0,
# CHECK: 88.0, 88.0, 88.0, 88.0, 96.0, 96.0, 96.0, 96.0,
# CHECK: 96.0, 96.0, 96.0, 96.0, 96.0, 104.0, 104.0, 104.0,
# CHECK: 104.0, 104.0, 104.0, 104.0, 112.0, 112.0, 112.0, 112.0,
# CHECK: 112.0, 112.0, 112.0, 112.0, 112.0, 120.0, 120.0, 120.0,
# CHECK: 120.0, 120.0, 120.0, 120.0, 128.0, 128.0, 128.0, 128.0,
# CHECK: 128.0, 128.0, 128.0, 128.0, 128.0, 128.0, 128.0, 128.0,
# CHECK: 128.0, 144.0, 144.0, 144.0, 144.0, 144.0, 144.0, 144.0,
# CHECK: 144.0, 144.0, 144.0, 144.0, 144.0, 144.0, 144.0, 144.0,
# CHECK: 160.0, 160.0, 160.0, 160.0, 160.0, 160.0, 160.0, 160.0,
# CHECK: 160.0, 160.0, 160.0, 160.0, 160.0, 160.0, 160.0, 160.0,
# CHECK: 160.0, 176.0, 176.0, 176.0, 176.0, 176.0, 176.0, 176.0,
# CHECK: 176.0, 176.0, 176.0, 176.0, 176.0, 176.0, 176.0, 176.0,
# CHECK: 192.0, 192.0, 192.0, 192.0, 192.0, 192.0, 192.0, 192.0,
# CHECK: 192.0, 192.0, 192.0, 192.0, 192.0, 192.0, 192.0, 192.0,
# CHECK: 192.0, 208.0, 208.0, 208.0, 208.0, 208.0, 208.0, 208.0,
# CHECK: 208.0, 208.0, 208.0, 208.0, 208.0, 208.0, 208.0, 208.0,
# CHECK: 224.0, 224.0, 224.0, 224.0, 224.0, 224.0, 224.0, 224.0,
# CHECK: 224.0, 224.0, 224.0, 224.0, 224.0, 224.0, 224.0, 224.0,
# CHECK: 224.0, 240.0, 240.0, 240.0, 240.0, 240.0, 240.0, 240.0,
# CHECK: 240.0, 240.0, 240.0, 240.0, 240.0, 240.0, 240.0, 240.0,
# CHECK: 256.0, 256.0, 256.0, 256.0, 256.0, 256.0, 256.0, 256.0,
fn test_simd_f32_to_e4m3():
    """Prints the integers -256..255 after a float32 -> float8_e4m3fn cast.

    Lane `i` of the float32 input holds `i - 256`. The cast is done directly
    in this function; no DeviceContext is involved. The 512 results are
    printed eight per line (64 lines), each value followed by ", ".
    """
    print("== test_simd_f32_to_e4m3")

    comptime count = 512
    comptime row_width = 8

    # Lane i holds the integer value i - 256, i.e. -256 through 255.
    var inputs = SIMD[DType.float32, count](0.0)
    for lane in range(count):
        inputs[lane] = lane - 256

    var narrowed = inputs.cast[DType.float8_e4m3fn]()

    # Emit a line break after every `row_width`-th lane.
    for lane in range(count):
        print(narrowed[lane], end=", ")
        if (lane + 1) % row_width == 0:
            print("")


fn test_simd_float8[
    dtype: DType,
    size: Int,
    target: DType,
](x: SIMD[dtype, size]):
    """Casts `x` to `target` and prints the result as 32 rows.

    This function is what the `test_simd_e4m3_to_*_ptx_path` tests in this
    file enqueue on a DeviceContext.

    Parameters:
        dtype: Element type of the input vector.
        size: Lane count of the input vector. Each row holds `size // 32`
            lanes, so lanes beyond `32 * (size // 32)` are not printed.
        target: Element type to cast to before printing.

    Args:
        x: The vector to cast and print.
    """
    comptime rows = 32
    comptime cols = size // rows

    var converted = x.cast[target]()

    for r in range(rows):
        # Index of the first lane belonging to this row.
        var row_start = r * cols
        for c in range(cols):
            print(converted[row_start + c], end=", ")
        print("")


# CHECK-LABEL: test_simd_e4m3_to_f16_ptx_path
# CHECK: 0.0, 0.001953125, 0.00390625, 0.005859375, 0.0078125, 0.009765625, 0.01171875, 0.013671875,
# CHECK: 0.015625, 0.017578125, 0.01953125, 0.021484375, 0.0234375, 0.025390625, 0.02734375, 0.029296875,
# CHECK: 0.03125, 0.03515625, 0.0390625, 0.04296875, 0.046875, 0.05078125, 0.0546875, 0.05859375,
# CHECK: 0.0625, 0.0703125, 0.078125, 0.0859375, 0.09375, 0.1015625, 0.109375, 0.1171875,
# CHECK: 0.125, 0.140625, 0.15625, 0.171875, 0.1875, 0.203125, 0.21875, 0.234375,
# CHECK: 0.25, 0.28125, 0.3125, 0.34375, 0.375, 0.40625, 0.4375, 0.46875,
# CHECK: 0.5, 0.5625, 0.625, 0.6875, 0.75, 0.8125, 0.875, 0.9375,
# CHECK: 1.0, 1.125, 1.25, 1.375, 1.5, 1.625, 1.75, 1.875,
# CHECK: 2.0, 2.25, 2.5, 2.75, 3.0, 3.25, 3.5, 3.75,
# CHECK: 4.0, 4.5, 5.0, 5.5, 6.0, 6.5, 7.0, 7.5,
# CHECK: 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0,
# CHECK: 16.0, 18.0, 20.0, 22.0, 24.0, 26.0, 28.0, 30.0,
# CHECK: 32.0, 36.0, 40.0, 44.0, 48.0, 52.0, 56.0, 60.0,
# CHECK: 64.0, 72.0, 80.0, 88.0, 96.0, 104.0, 112.0, 120.0,
# CHECK: 128.0, 144.0, 160.0, 176.0, 192.0, 208.0, 224.0, 240.0,
# CHECK: 256.0, 288.0, 320.0, 352.0, 384.0, 416.0, 448.0, nan,
# CHECK: -0.0, -0.001953125, -0.00390625, -0.005859375, -0.0078125, -0.009765625, -0.01171875, -0.013671875,
# CHECK: -0.015625, -0.017578125, -0.01953125, -0.021484375, -0.0234375, -0.025390625, -0.02734375, -0.029296875,
# CHECK: -0.03125, -0.03515625, -0.0390625, -0.04296875, -0.046875, -0.05078125, -0.0546875, -0.05859375,
# CHECK: -0.0625, -0.0703125, -0.078125, -0.0859375, -0.09375, -0.1015625, -0.109375, -0.1171875,
# CHECK: -0.125, -0.140625, -0.15625, -0.171875, -0.1875, -0.203125, -0.21875, -0.234375,
# CHECK: -0.25, -0.28125, -0.3125, -0.34375, -0.375, -0.40625, -0.4375, -0.46875,
# CHECK: -0.5, -0.5625, -0.625, -0.6875, -0.75, -0.8125, -0.875, -0.9375,
# CHECK: -1.0, -1.125, -1.25, -1.375, -1.5, -1.625, -1.75, -1.875,
# CHECK: -2.0, -2.25, -2.5, -2.75, -3.0, -3.25, -3.5, -3.75,
# CHECK: -4.0, -4.5, -5.0, -5.5, -6.0, -6.5, -7.0, -7.5,
# CHECK: -8.0, -9.0, -10.0, -11.0, -12.0, -13.0, -14.0, -15.0,
# CHECK: -16.0, -18.0, -20.0, -22.0, -24.0, -26.0, -28.0, -30.0,
# CHECK: -32.0, -36.0, -40.0, -44.0, -48.0, -52.0, -56.0, -60.0,
# CHECK: -64.0, -72.0, -80.0, -88.0, -96.0, -104.0, -112.0, -120.0,
# CHECK: -128.0, -144.0, -160.0, -176.0, -192.0, -208.0, -224.0, -240.0,
# CHECK: -256.0, -288.0, -320.0, -352.0, -384.0, -416.0, -448.0, nan,
fn test_simd_e4m3_to_f16_ptx_path(ctx: DeviceContext) raises:
    """Runs the float8_e4m3fn -> float16 cast through a DeviceContext.

    Builds a vector whose lane `i` is `UInt8(i)` bitcast to float8_e4m3fn,
    enqueues `test_simd_float8` (which casts and prints) with `grid_dim=1`
    and `block_dim=1`, then waits for it via `ctx.synchronize()`.

    Args:
        ctx: Device context the kernel is enqueued on.

    Raises:
        Propagates any error raised while enqueueing or synchronizing.
    """
    print("== test_simd_e4m3_to_f16_ptx_path")

    comptime num_patterns = 256
    comptime cast_kernel = test_simd_float8[
        DType.float8_e4m3fn, num_patterns, DType.float16
    ]

    # Fill lane `bits` with the float8 value whose raw byte is `bits`.
    var fp8_values = SIMD[DType.float8_e4m3fn, num_patterns](0.0)
    for bits in range(num_patterns):
        fp8_values[bits] = bitcast[DType.float8_e4m3fn](UInt8(bits))

    ctx.enqueue_function_experimental[cast_kernel](
        fp8_values, grid_dim=1, block_dim=1
    )
    ctx.synchronize()


# CHECK-LABEL: test_simd_e4m3_to_f32_ptx_path
# CHECK: 0.0, 0.001953125, 0.00390625, 0.005859375, 0.0078125, 0.009765625, 0.01171875, 0.013671875,
# CHECK: 0.015625, 0.017578125, 0.01953125, 0.021484375, 0.0234375, 0.025390625, 0.02734375, 0.029296875,
# CHECK: 0.03125, 0.03515625, 0.0390625, 0.04296875, 0.046875, 0.05078125, 0.0546875, 0.05859375,
# CHECK: 0.0625, 0.0703125, 0.078125, 0.0859375, 0.09375, 0.1015625, 0.109375, 0.1171875,
# CHECK: 0.125, 0.140625, 0.15625, 0.171875, 0.1875, 0.203125, 0.21875, 0.234375,
# CHECK: 0.25, 0.28125, 0.3125, 0.34375, 0.375, 0.40625, 0.4375, 0.46875,
# CHECK: 0.5, 0.5625, 0.625, 0.6875, 0.75, 0.8125, 0.875, 0.9375,
# CHECK: 1.0, 1.125, 1.25, 1.375, 1.5, 1.625, 1.75, 1.875,
# CHECK: 2.0, 2.25, 2.5, 2.75, 3.0, 3.25, 3.5, 3.75,
# CHECK: 4.0, 4.5, 5.0, 5.5, 6.0, 6.5, 7.0, 7.5,
# CHECK: 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0,
# CHECK: 16.0, 18.0, 20.0, 22.0, 24.0, 26.0, 28.0, 30.0,
# CHECK: 32.0, 36.0, 40.0, 44.0, 48.0, 52.0, 56.0, 60.0,
# CHECK: 64.0, 72.0, 80.0, 88.0, 96.0, 104.0, 112.0, 120.0,
# CHECK: 128.0, 144.0, 160.0, 176.0, 192.0, 208.0, 224.0, 240.0,
# CHECK: 256.0, 288.0, 320.0, 352.0, 384.0, 416.0, 448.0, nan,
# CHECK: -0.0, -0.001953125, -0.00390625, -0.005859375, -0.0078125, -0.009765625, -0.01171875, -0.013671875,
# CHECK: -0.015625, -0.017578125, -0.01953125, -0.021484375, -0.0234375, -0.025390625, -0.02734375, -0.029296875,
# CHECK: -0.03125, -0.03515625, -0.0390625, -0.04296875, -0.046875, -0.05078125, -0.0546875, -0.05859375,
# CHECK: -0.0625, -0.0703125, -0.078125, -0.0859375, -0.09375, -0.1015625, -0.109375, -0.1171875,
# CHECK: -0.125, -0.140625, -0.15625, -0.171875, -0.1875, -0.203125, -0.21875, -0.234375,
# CHECK: -0.25, -0.28125, -0.3125, -0.34375, -0.375, -0.40625, -0.4375, -0.46875,
# CHECK: -0.5, -0.5625, -0.625, -0.6875, -0.75, -0.8125, -0.875, -0.9375,
# CHECK: -1.0, -1.125, -1.25, -1.375, -1.5, -1.625, -1.75, -1.875,
# CHECK: -2.0, -2.25, -2.5, -2.75, -3.0, -3.25, -3.5, -3.75,
# CHECK: -4.0, -4.5, -5.0, -5.5, -6.0, -6.5, -7.0, -7.5,
# CHECK: -8.0, -9.0, -10.0, -11.0, -12.0, -13.0, -14.0, -15.0,
# CHECK: -16.0, -18.0, -20.0, -22.0, -24.0, -26.0, -28.0, -30.0,
# CHECK: -32.0, -36.0, -40.0, -44.0, -48.0, -52.0, -56.0, -60.0,
# CHECK: -64.0, -72.0, -80.0, -88.0, -96.0, -104.0, -112.0, -120.0,
# CHECK: -128.0, -144.0, -160.0, -176.0, -192.0, -208.0, -224.0, -240.0,
# CHECK: -256.0, -288.0, -320.0, -352.0, -384.0, -416.0, -448.0, nan,
fn test_simd_e4m3_to_f32_ptx_path(ctx: DeviceContext) raises:
    """Runs the float8_e4m3fn -> float32 cast through a DeviceContext.

    Builds a vector whose lane `i` is `UInt8(i)` bitcast to float8_e4m3fn,
    enqueues `test_simd_float8` (which casts and prints) with `grid_dim=1`
    and `block_dim=1`, then waits for it via `ctx.synchronize()`.

    Args:
        ctx: Device context the kernel is enqueued on.

    Raises:
        Propagates any error raised while enqueueing or synchronizing.
    """
    print("== test_simd_e4m3_to_f32_ptx_path")

    comptime num_patterns = 256
    comptime cast_kernel = test_simd_float8[
        DType.float8_e4m3fn, num_patterns, DType.float32
    ]

    # Fill lane `bits` with the float8 value whose raw byte is `bits`.
    var fp8_values = SIMD[DType.float8_e4m3fn, num_patterns](0.0)
    for bits in range(num_patterns):
        fp8_values[bits] = bitcast[DType.float8_e4m3fn](UInt8(bits))

    ctx.enqueue_function_experimental[cast_kernel](
        fp8_values, grid_dim=1, block_dim=1
    )
    ctx.synchronize()


fn test_simd_float32[
    size: Int,
    target: DType,
](x: SIMD[DType.float32, size]):
    """Casts a float32 vector to `target` and prints the result as 64 rows.

    This function is what `test_simd_f32_to_e4m3_ptx_path` in this file
    enqueues on a DeviceContext.

    Parameters:
        size: Lane count of the input vector. Each row holds `size // 64`
            lanes, so lanes beyond `64 * (size // 64)` are not printed.
        target: Element type to cast to before printing.

    Args:
        x: The float32 vector to cast and print.
    """
    comptime rows = 64
    comptime cols = size // rows

    var converted = x.cast[target]()

    for r in range(rows):
        # Index of the first lane belonging to this row.
        var row_start = r * cols
        for c in range(cols):
            print(converted[row_start + c], end=", ")
        print("")


# CHECK-LABEL: test_simd_f32_to_e4m3_ptx_path
# CHECK: -256.0, -256.0, -256.0, -256.0, -256.0, -256.0, -256.0, -256.0,
# CHECK: -256.0, -240.0, -240.0, -240.0, -240.0, -240.0, -240.0, -240.0,
# CHECK: -240.0, -240.0, -240.0, -240.0, -240.0, -240.0, -240.0, -240.0,
# CHECK: -224.0, -224.0, -224.0, -224.0, -224.0, -224.0, -224.0, -224.0,
# CHECK: -224.0, -224.0, -224.0, -224.0, -224.0, -224.0, -224.0, -224.0,
# CHECK: -224.0, -208.0, -208.0, -208.0, -208.0, -208.0, -208.0, -208.0,
# CHECK: -208.0, -208.0, -208.0, -208.0, -208.0, -208.0, -208.0, -208.0,
# CHECK: -192.0, -192.0, -192.0, -192.0, -192.0, -192.0, -192.0, -192.0,
# CHECK: -192.0, -192.0, -192.0, -192.0, -192.0, -192.0, -192.0, -192.0,
# CHECK: -192.0, -176.0, -176.0, -176.0, -176.0, -176.0, -176.0, -176.0,
# CHECK: -176.0, -176.0, -176.0, -176.0, -176.0, -176.0, -176.0, -176.0,
# CHECK: -160.0, -160.0, -160.0, -160.0, -160.0, -160.0, -160.0, -160.0,
# CHECK: -160.0, -160.0, -160.0, -160.0, -160.0, -160.0, -160.0, -160.0,
# CHECK: -160.0, -144.0, -144.0, -144.0, -144.0, -144.0, -144.0, -144.0,
# CHECK: -144.0, -144.0, -144.0, -144.0, -144.0, -144.0, -144.0, -144.0,
# CHECK: -128.0, -128.0, -128.0, -128.0, -128.0, -128.0, -128.0, -128.0,
# CHECK: -128.0, -128.0, -128.0, -128.0, -128.0, -120.0, -120.0, -120.0,
# CHECK: -120.0, -120.0, -120.0, -120.0, -112.0, -112.0, -112.0, -112.0,
# CHECK: -112.0, -112.0, -112.0, -112.0, -112.0, -104.0, -104.0, -104.0,
# CHECK: -104.0, -104.0, -104.0, -104.0, -96.0, -96.0, -96.0, -96.0,
# CHECK: -96.0, -96.0, -96.0, -96.0, -96.0, -88.0, -88.0, -88.0,
# CHECK: -88.0, -88.0, -88.0, -88.0, -80.0, -80.0, -80.0, -80.0,
# CHECK: -80.0, -80.0, -80.0, -80.0, -80.0, -72.0, -72.0, -72.0,
# CHECK: -72.0, -72.0, -72.0, -72.0, -64.0, -64.0, -64.0, -64.0,
# CHECK: -64.0, -64.0, -64.0, -60.0, -60.0, -60.0, -56.0, -56.0,
# CHECK: -56.0, -56.0, -56.0, -52.0, -52.0, -52.0, -48.0, -48.0,
# CHECK: -48.0, -48.0, -48.0, -44.0, -44.0, -44.0, -40.0, -40.0,
# CHECK: -40.0, -40.0, -40.0, -36.0, -36.0, -36.0, -32.0, -32.0,
# CHECK: -32.0, -32.0, -30.0, -28.0, -28.0, -28.0, -26.0, -24.0,
# CHECK: -24.0, -24.0, -22.0, -20.0, -20.0, -20.0, -18.0, -16.0,
# CHECK: -16.0, -15.0, -14.0, -13.0, -12.0, -11.0, -10.0, -9.0,
# CHECK: -8.0, -7.0, -6.0, -5.0, -4.0, -3.0, -2.0, -1.0,
# CHECK: 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0,
# CHECK: 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0,
# CHECK: 16.0, 16.0, 18.0, 20.0, 20.0, 20.0, 22.0, 24.0,
# CHECK: 24.0, 24.0, 26.0, 28.0, 28.0, 28.0, 30.0, 32.0,
# CHECK: 32.0, 32.0, 32.0, 36.0, 36.0, 36.0, 40.0, 40.0,
# CHECK: 40.0, 40.0, 40.0, 44.0, 44.0, 44.0, 48.0, 48.0,
# CHECK: 48.0, 48.0, 48.0, 52.0, 52.0, 52.0, 56.0, 56.0,
# CHECK: 56.0, 56.0, 56.0, 60.0, 60.0, 60.0, 64.0, 64.0,
# CHECK: 64.0, 64.0, 64.0, 64.0, 64.0, 72.0, 72.0, 72.0,
# CHECK: 72.0, 72.0, 72.0, 72.0, 80.0, 80.0, 80.0, 80.0,
# CHECK: 80.0, 80.0, 80.0, 80.0, 80.0, 88.0, 88.0, 88.0,
# CHECK: 88.0, 88.0, 88.0, 88.0, 96.0, 96.0, 96.0, 96.0,
# CHECK: 96.0, 96.0, 96.0, 96.0, 96.0, 104.0, 104.0, 104.0,
# CHECK: 104.0, 104.0, 104.0, 104.0, 112.0, 112.0, 112.0, 112.0,
# CHECK: 112.0, 112.0, 112.0, 112.0, 112.0, 120.0, 120.0, 120.0,
# CHECK: 120.0, 120.0, 120.0, 120.0, 128.0, 128.0, 128.0, 128.0,
# CHECK: 128.0, 128.0, 128.0, 128.0, 128.0, 128.0, 128.0, 128.0,
# CHECK: 128.0, 144.0, 144.0, 144.0, 144.0, 144.0, 144.0, 144.0,
# CHECK: 144.0, 144.0, 144.0, 144.0, 144.0, 144.0, 144.0, 144.0,
# CHECK: 160.0, 160.0, 160.0, 160.0, 160.0, 160.0, 160.0, 160.0,
# CHECK: 160.0, 160.0, 160.0, 160.0, 160.0, 160.0, 160.0, 160.0,
# CHECK: 160.0, 176.0, 176.0, 176.0, 176.0, 176.0, 176.0, 176.0,
# CHECK: 176.0, 176.0, 176.0, 176.0, 176.0, 176.0, 176.0, 176.0,
# CHECK: 192.0, 192.0, 192.0, 192.0, 192.0, 192.0, 192.0, 192.0,
# CHECK: 192.0, 192.0, 192.0, 192.0, 192.0, 192.0, 192.0, 192.0,
# CHECK: 192.0, 208.0, 208.0, 208.0, 208.0, 208.0, 208.0, 208.0,
# CHECK: 208.0, 208.0, 208.0, 208.0, 208.0, 208.0, 208.0, 208.0,
# CHECK: 224.0, 224.0, 224.0, 224.0, 224.0, 224.0, 224.0, 224.0,
# CHECK: 224.0, 224.0, 224.0, 224.0, 224.0, 224.0, 224.0, 224.0,
# CHECK: 224.0, 240.0, 240.0, 240.0, 240.0, 240.0, 240.0, 240.0,
# CHECK: 240.0, 240.0, 240.0, 240.0, 240.0, 240.0, 240.0, 240.0,
# CHECK: 256.0, 256.0, 256.0, 256.0, 256.0, 256.0, 256.0, 256.0,
fn test_simd_f32_to_e4m3_ptx_path(ctx: DeviceContext) raises:
    """Runs the float32 -> float8_e4m3fn cast through a DeviceContext.

    Builds a 512-lane float32 vector whose lane `i` holds `i - 256`,
    enqueues `test_simd_float32` (which casts and prints) with `grid_dim=1`
    and `block_dim=1`, then waits for it via `ctx.synchronize()`.

    Args:
        ctx: Device context the kernel is enqueued on.

    Raises:
        Propagates any error raised while enqueueing or synchronizing.
    """
    print("== test_simd_f32_to_e4m3_ptx_path")

    comptime count = 512
    comptime cast_kernel = test_simd_float32[count, DType.float8_e4m3fn]

    # Lane i holds the integer value i - 256, i.e. -256 through 255.
    var inputs = SIMD[DType.float32, count](0.0)
    for lane in range(count):
        inputs[lane] = lane - 256

    ctx.enqueue_function_experimental[cast_kernel](
        inputs, grid_dim=1, block_dim=1
    )
    ctx.synchronize()


# Entry point: runs every test in this file. The call order matters because
# each test prints a "== <name>" banner followed by its values, and the
# expected-output comments above each test appear in the same order.
def main():
    # Tests that cast directly, without a DeviceContext.
    test_e4m3fn_initialization()
    test_simd_e4m3_to_f32()
    test_simd_e4m3_to_f16()
    test_simd_e4m3_to_bf16()

    test_simd_f32_to_e4m3()

    # Tests that enqueue the cast-and-print kernels on a DeviceContext; the
    # context is scoped to this `with` block.
    with DeviceContext() as ctx:
        test_simd_e4m3_to_f16_ptx_path(ctx)
        test_simd_e4m3_to_f32_ptx_path(ctx)
        test_simd_f32_to_e4m3_ptx_path(ctx)
